make_parallel for network containers #985
Conversation
Thanks for the nice improvement, Wei.
Looks good as is, just small comments.
Thanks,
Le
@@ -171,7 +178,8 @@ def __init__(self,
                  stack_size,
                  pooling_size=1,
                  dtype=torch.float32,
-                 mode='skip'):
+                 mode='skip',
+                 name='TemporalPool'):
Looking at the TemporalPool example above, it seems that if pooling_size == 2, we always ignore every other timestep (half of them).
Have we considered pooling every timestep, instead of pooling every pooling_size timesteps?
Is mode "avg" what you want?
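For illustration, here is a minimal standalone sketch of the two behaviors discussed; the function name and windowing details are assumptions, not ALF's actual TemporalPool implementation. With pooling_size=2, 'skip' keeps only every other timestep, while 'avg' averages each window so every timestep contributes.

import torch

def temporal_pool_sketch(x, pooling_size=2, mode='skip'):
    # x: [num_timesteps, ...]; assumes num_timesteps is divisible by pooling_size.
    if mode == 'skip':
        # Keep only the last timestep of each window; the others are ignored.
        return x[pooling_size - 1::pooling_size]
    elif mode == 'avg':
        # Average each window of pooling_size timesteps; no timestep is discarded.
        windows = x.reshape(-1, pooling_size, *x.shape[1:])
        return windows.mean(dim=1)

x = torch.arange(8, dtype=torch.float32)
print(temporal_pool_sketch(x, 2, 'skip'))  # tensor([1., 3., 5., 7.])
print(temporal_pool_sketch(x, 2, 'avg'))   # tensor([0.5000, 2.5000, 4.5000, 6.5000])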
A parallel network has ``n`` copies of network with the same structure but
different independently initialized parameters. The parallel network can
process a batch of the data with shape [batch_size, n, ...] using ``n``
Should we automatically convert data with shape [batch_size, ...] to [batch_size, n, ...]?
As a network can have sub-networks, doing this check for all of them can be wasteful. So it's intended that the user uses make_parallel_input to do the conversion.
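As a hedged illustration of what such a conversion would do (the helper name below is hypothetical, not ALF's make_parallel_input), replicating an input of shape [batch_size, ...] along a new dim 1 gives every one of the n replicas the same data:

import torch

def replicate_for_parallel(x, n):
    # [batch_size, ...] -> [batch_size, n, ...] by repeating along a new dim 1.
    return x.unsqueeze(1).expand(x.shape[0], n, *x.shape[1:])

x = torch.randn(32, 10)               # [batch_size, num_features]
x_parallel = replicate_for_parallel(x, 5)
print(x_parallel.shape)               # torch.Size([32, 5, 10])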
alf/nest/utils.py
@@ -104,6 +105,10 @@ def _combine_flat(self, tensors):
         else:
             return torch.cat(tensors, dim=self._dim)

+    def make_parallel(self, n):
+        dim = self._dim if self._dim < 0 else self._dim + 1
We should add comments saying that self._dim excludes the batch dim and the parallel dim. The current implementation seems not to ignore the batch dim.
Changed the behavior of NestConcat to not include the batch dim.
Comments added.
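A rough sketch of the dim-shifting logic under the convention agreed above, i.e. that self._dim excludes both the batch dim and the parallel dim; the class and method names are illustrative only, not ALF's NestConcat:

import torch

class ConcatSketch:
    def __init__(self, dim=0):
        # dim is specified relative to a single, unbatched sample.
        self._dim = dim

    def combine(self, tensors):
        # Regular tensors are [batch_size, ...]: skip the leading batch dim.
        dim = self._dim if self._dim < 0 else self._dim + 1
        return torch.cat(tensors, dim=dim)

    def combine_parallel(self, tensors):
        # Parallel tensors are [batch_size, n, ...]: skip batch and parallel dims.
        dim = self._dim if self._dim < 0 else self._dim + 2
        return torch.cat(tensors, dim=dim)

a = torch.randn(4, 3, 8)   # [batch_size, n, num_features]
b = torch.randn(4, 3, 8)
print(ConcatSketch(dim=0).combine_parallel([a, b]).shape)  # torch.Size([4, 3, 16])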
The comments are very helpful. Just some more minor points.
* make_parallel for network containers
* Address comments
* Address further comments
This change supports creating parallel networks for network containers. This is achieved by implementing make_parallel for various layers and transforming all the modules in the container.
With this change, it will be trivial to implement make_parallel for all kinds of networks if they are implemented using network containers.
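The idea can be sketched as follows, using hypothetical classes for illustration rather than ALF's actual containers or layers: a container's make_parallel simply asks each contained module for its parallel version, so any network assembled from containers gets make_parallel essentially for free.

import torch
import torch.nn as nn

class ParallelLinearSketch(nn.Module):
    # n independently parameterized Linear copies; input shape [B, n, in_features].
    def __init__(self, n, in_features, out_features):
        super().__init__()
        self.weight = nn.Parameter(0.05 * torch.randn(n, in_features, out_features))
        self.bias = nn.Parameter(torch.zeros(n, out_features))

    def forward(self, x):  # x: [B, n, in_features]
        return torch.einsum('bni,nio->bno', x, self.weight) + self.bias

class LinearSketch(nn.Linear):
    def make_parallel(self, n):
        return ParallelLinearSketch(n, self.in_features, self.out_features)

class SequentialSketch(nn.Sequential):
    def make_parallel(self, n):
        # The container just transforms every module it holds.
        return SequentialSketch(*(m.make_parallel(n) for m in self))

net = SequentialSketch(LinearSketch(10, 16), LinearSketch(16, 4))
pnet = net.make_parallel(5)
x = torch.randn(32, 5, 10)            # [batch_size, n, ...]
print(pnet(x).shape)                  # torch.Size([32, 5, 4])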